// -*- coding: utf-8 -*-
//
// disktest - Storage tester
//
// Copyright 2020-2024 Michael Büsch <m@bues.ch>
//
// Licensed under the Apache License version 2.0
// or the MIT license, at your option.
// SPDX-License-Identifier: Apache-2.0 OR MIT
//

use anyhow as ah;
use clap::builder::ValueParser;
use clap::error::ErrorKind::{DisplayHelp, DisplayVersion};
use clap::{value_parser, Arg, ArgAction, Command};
use disktest_lib::{gen_seed_string, parsebytes, DisktestQuiet, DtStreamType};
use std::ffi::OsString;
use std::path::PathBuf;

/// Length of the generated seed.
///
/// Passed to `gen_seed_string` by `parse_args` when no `--seed` was given
/// on the command line.
const DEFAULT_GEN_SEED_LEN: usize = 40;

/// Program description for `--help`.
///
/// The text deliberately ends with "Example usage:" and a newline, because
/// `EXAMPLE` is appended directly after it when the command is built.
const ABOUT: &str = "\
Solid State Disk (SSD), Non-Volatile Memory Storage (NVMe), Hard Disk (HDD), USB Stick, SD-Card tester.

This program can write a cryptographically secure pseudo random stream to a disk,
read it back and verify it by comparing it to the expected stream.

Example usage:
";

/// Example invocation appended to `ABOUT` on non-Windows systems.
#[cfg(not(target_os = "windows"))]
const EXAMPLE: &str = "\
disktest --write --verify -j0 /dev/sdc";

/// Example invocation appended to `ABOUT` on Windows, using a raw drive path.
#[cfg(target_os = "windows")]
const EXAMPLE: &str = "\
disktest --write --verify -j0 \\\\.\\E:";

/// Platform independent help text for the positional `device` argument.
///
/// It ends with a newline because the platform specific `HELP_DEVICE_OS`
/// is appended right after it.
const HELP_DEVICE: &str = "\
Device node of the disk or file path to access.
";

/// Device help details for non-Windows systems (device nodes or file paths).
#[cfg(not(target_os = "windows"))]
const HELP_DEVICE_OS: &str = "\
This may be the /dev/sdX or /dev/mmcblkX or similar
device node of the disk. It may also be an arbitrary path to a location in a filesystem.";

/// Device help details for Windows (file paths or raw drive paths).
#[cfg(target_os = "windows")]
const HELP_DEVICE_OS: &str = "\
This may be a path to the location on the disk to be tested (e.g. E:\\testfile)
or a raw drive (e.g. \\\\.\\E: or \\\\.\\PhysicalDrive2).";

/// Help text for `--write` / `-w`.
const HELP_WRITE: &str = "\
Write pseudo random data to the device.
If this option is not given, then disktest will operate in
verify-only mode instead, as if only --verify was given.
If both --write and --verify are specified, then the device
will first be written and then be verified with the same seed.";

/// Help text for `--verify` / `-v`.
const HELP_VERIFY: &str = "\
In verify-mode the disk will be read and compared to the expected pseudo
random sequence.
If both --write and --verify are specified, then the device
will first be written and then be verified with the same seed.";

/// Help text for `--seek` / `-s`; the value is parsed by `parsebytes`.
const HELP_SEEK: &str = "\
Seek to the specified byte position on disk
before starting the write/verify operation. This skips the specified
amount of bytes on the disk and also fast forwards the random number generator.
";

/// Help text for `--bytes` / `-b`; the value is parsed by `parsebytes`.
const HELP_BYTES: &str = "\
Number of bytes to write/verify.
If not given, then the whole disk will be overwritten/verified.
";

/// Help text for `--algorithm` / `-A`.
const HELP_ALGORITHM: &str = "\
Select the random number generator algorithm.
ChaCha12 and ChaCha8 are less cryptographically secure than ChaCha20, but
faster. CRC is even faster, but not cryptographically secure at all.
";

/// Help text for `--seed` / `-S`.
const HELP_SEED: &str = "\
The seed to use for random number stream generation.
The seed may be any random string (e.g. a long passphrase).
If no seed is given, then a secure random seed will be generated
and also printed to the console.";

/// Help text for `--invert-pattern` / `-i`.
const HELP_INVERT_PATTERN: &str = "\
Invert the bit pattern generated by the random number generator.
This can be useful, if a second write/verify run with a strictly
inverted test bit pattern is desired.";

/// Help text for `--threads` / `-j`.
///
/// Shown verbatim by `--help`; the value itself is limited to 0..=65536 by
/// the argument's value parser.
const HELP_THREADS: &str = "\
The number of CPUs to use.
The special value 0 will select the maximum number of online CPUs in the
system. If the number of threads is equal to the number of CPUs, it is
optimal for performance. The number of threads must be equal during
corresponding verify and write mode runs. Otherwise the verification will fail.
";

/// Help text for `--rounds` / `-R`.
///
/// `parse_args` turns the special value 0 into `u64::MAX` rounds.
const HELP_ROUNDS: &str = "\
The number of rounds to execute the whole process.
This normally defaults to 1 to only run the write and/or verify once.
But you may specify more than one round to repeat write and/or verify
multiple times.
If --write mode is active, then different random data will be written
on each round.
The special value of 0 rounds will execute an infinite number of rounds.
";

/// Help text for `--start-round`.
const HELP_START_ROUND: &str = "\
Start at the specified round index. (= Skip this many rounds).
Defaults to the first round (0).
";

/// Help text for `--quiet` / `-q`.
///
/// `parse_args` maps the numeric level onto a `DisktestQuiet` variant.
const HELP_QUIET: &str = "\
Quiet level:
0: Normal verboseness.
1: Reduced verboseness.
2: No informational output.
3: No warnings.
";

/// All command line arguments, after parsing and normalization by `parse_args`.
#[derive(Debug)]
pub struct Args {
    /// Device node or file path to test.
    pub device: PathBuf,
    /// Write the pseudo random stream to the device.
    pub write: bool,
    /// Read the device back and compare it to the expected stream.
    /// Always `true` when `write` is `false` (verify-only mode).
    pub verify: bool,
    /// Byte position to seek to before starting the write/verify operation.
    pub seek: u64,
    /// Maximum number of bytes to write/verify.
    /// Defaults to `u64::MAX`, i.e. the whole device.
    pub max_bytes: u64,
    /// Random number generator algorithm selected with `--algorithm`.
    pub algorithm: DtStreamType,
    /// Seed string: either taken from `--seed` or freshly generated.
    pub seed: String,
    /// `true` if `seed` came from `--seed`, `false` if it was generated.
    pub user_seed: bool,
    /// Invert the generated bit pattern.
    pub invert_pattern: bool,
    /// Number of worker threads as given by `--threads` (0 = all online CPUs).
    pub threads: usize,
    /// Total number of rounds. `--rounds 0` is stored as `u64::MAX`, and the
    /// value is always greater than `start_round`.
    pub rounds: u64,
    /// Index of the first round to execute.
    pub start_round: u64,
    /// Output verbosity level.
    pub quiet: DisktestQuiet,
}

/// Parse all command line arguments and put them into a structure.
pub fn parse_args<I, T>(args: I) -> ah::Result<Args>
where
    I: IntoIterator<Item = T>,
    T: Into<OsString> + Clone,
{
    let about = ABOUT.to_string() + EXAMPLE;
    let help_device = HELP_DEVICE.to_string() + HELP_DEVICE_OS;

    let args = Command::new("disktest")
        .about(about)
        .arg(
            Arg::new("device")
                .index(1)
                .required(true)
                .value_parser(value_parser!(PathBuf))
                .help(help_device),
        )
        .arg(
            Arg::new("write")
                .long("write")
                .short('w')
                .action(ArgAction::SetTrue)
                .help(HELP_WRITE),
        )
        .arg(
            Arg::new("verify")
                .long("verify")
                .short('v')
                .action(ArgAction::SetTrue)
                .help(HELP_VERIFY),
        )
        .arg(
            Arg::new("seek")
                .long("seek")
                .short('s')
                .value_name("BYTES")
                .default_value("0")
                .value_parser(ValueParser::new(parsebytes))
                .help(HELP_SEEK),
        )
        .arg(
            Arg::new("bytes")
                .long("bytes")
                .short('b')
                .value_name("BYTES")
                .default_value("18446744073709551615")
                .value_parser(ValueParser::new(parsebytes))
                .help(HELP_BYTES),
        )
        .arg(
            Arg::new("algorithm")
                .long("algorithm")
                .short('A')
                .value_name("ALG")
                .default_value("CHACHA20")
                .value_parser(["CHACHA8", "CHACHA12", "CHACHA20", "CRC"])
                .ignore_case(true)
                .help(HELP_ALGORITHM),
        )
        .arg(
            Arg::new("seed")
                .long("seed")
                .short('S')
                .value_name("SEED")
                .help(HELP_SEED),
        )
        .arg(
            Arg::new("invert-pattern")
                .long("invert-pattern")
                .short('i')
                .action(ArgAction::SetTrue)
                .help(HELP_INVERT_PATTERN),
        )
        .arg(
            Arg::new("threads")
                .long("threads")
                .short('j')
                .value_name("NUM")
                .default_value("1")
                .value_parser(value_parser!(u32).range(0_i64..=u16::MAX as i64 + 1))
                .help(HELP_THREADS),
        )
        .arg(
            Arg::new("rounds")
                .long("rounds")
                .short('R')
                .value_name("NUM")
                .default_value("1")
                .value_parser(value_parser!(u64))
                .help(HELP_ROUNDS),
        )
        .arg(
            Arg::new("start-round")
                .long("start-round")
                .value_name("IDX")
                .default_value("0")
                .value_parser(value_parser!(u64).range(0_u64..=u64::MAX - 1))
                .help(HELP_START_ROUND),
        )
        .arg(
            Arg::new("quiet")
                .long("quiet")
                .short('q')
                .value_name("LVL")
                .default_value("0")
                .value_parser(value_parser!(u8))
                .help(HELP_QUIET),
        )
        .try_get_matches_from(args);

    let args = match args {
        Ok(x) => x,
        Err(e) => {
            match e.kind() {
                DisplayHelp | DisplayVersion => {
                    print!("{}", e);
                    std::process::exit(0);
                }
                _ => (),
            };
            return Err(ah::format_err!("{}", e));
        }
    };

    let quiet = *args.get_one::<u8>("quiet").unwrap();
    let quiet = if quiet == DisktestQuiet::Normal as u8 {
        DisktestQuiet::Normal
    } else if quiet == DisktestQuiet::Reduced as u8 {
        DisktestQuiet::Reduced
    } else if quiet == DisktestQuiet::NoInfo as u8 {
        DisktestQuiet::NoInfo
    } else {
        DisktestQuiet::NoWarn
    };

    let device = args.get_one::<PathBuf>("device").unwrap().clone();

    let write = args.get_flag("write");
    let mut verify = args.get_flag("verify");
    if !write && !verify {
        verify = true;
    }

    let seek = *args.get_one::<u64>("seek").unwrap();

    let max_bytes = *args.get_one::<u64>("bytes").unwrap();

    let algorithm = match args
        .get_one::<String>("algorithm")
        .unwrap()
        .to_ascii_uppercase()
        .as_str()
    {
        "CHACHA8" => DtStreamType::ChaCha8,
        "CHACHA12" => DtStreamType::ChaCha12,
        "CHACHA20" => DtStreamType::ChaCha20,
        "CRC" => DtStreamType::Crc,
        _ => panic!("Invalid algorithm parameter."),
    };

    let (seed, user_seed) = match args.get_one::<String>("seed") {
        Some(x) => (x.clone(), true),
        None => (gen_seed_string(DEFAULT_GEN_SEED_LEN), false),
    };
    if !user_seed && verify && !write {
        return Err(ah::format_err!(
            "Verify-only mode requires --seed. \
             Please either provide a --seed, \
             or enable --verify and --write mode."
        ));
    }

    let invert_pattern = args.get_flag("invert-pattern");

    let threads = *args.get_one::<u32>("threads").unwrap_or(&1) as usize;

    let mut rounds = *args.get_one::<u64>("rounds").unwrap_or(&1);
    if rounds == 0 {
        rounds = u64::MAX;
    }
    let start_round = *args.get_one::<u64>("start-round").unwrap_or(&0);
    if start_round >= rounds {
        rounds = start_round + 1;
    }

    Ok(Args {
        device,
        write,
        verify,
        seek,
        max_bytes,
        algorithm,
        seed,
        user_seed,
        invert_pattern,
        threads,
        rounds,
        start_round,
        quiet,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use disktest_lib::Disktest;

    #[test]
    fn test_parse_args() {
        assert!(parse_args(vec!["disktest", "--does-not-exist"]).is_err());

        let a = parse_args(vec!["disktest", "-Sx", "/dev/foobar"]).unwrap();
        assert_eq!(a.device, PathBuf::from("/dev/foobar"));
        assert!(!a.write);
        assert!(a.verify);
        assert_eq!(a.seek, 0);
        assert_eq!(a.max_bytes, Disktest::UNLIMITED);
        assert_eq!(a.algorithm, DtStreamType::ChaCha20);
        assert_eq!(a.seed, "x");
        assert!(a.user_seed);
        assert!(!a.invert_pattern);
        assert_eq!(a.threads, 1);
        assert_eq!(a.quiet, DisktestQuiet::Normal);

        let a = parse_args(vec!["disktest", "--write", "/dev/foobar"]).unwrap();
        assert_eq!(a.device, PathBuf::from("/dev/foobar"));
        assert!(a.write);
        assert!(!a.verify);
        assert!(!a.user_seed);
        let a = parse_args(vec!["disktest", "-w", "/dev/foobar"]).unwrap();
        assert_eq!(a.device, PathBuf::from("/dev/foobar"));
        assert!(a.write);
        assert!(!a.verify);
        assert!(!a.user_seed);

        let a = parse_args(vec!["disktest", "--write", "--verify", "/dev/foobar"]).unwrap();
        assert_eq!(a.device, PathBuf::from("/dev/foobar"));
        assert!(a.write);
        assert!(a.verify);
        assert!(!a.user_seed);
        let a = parse_args(vec!["disktest", "-w", "-v", "/dev/foobar"]).unwrap();
        assert_eq!(a.device, PathBuf::from("/dev/foobar"));
        assert!(a.write);
        assert!(a.verify);
        assert!(!a.user_seed);

        let a = parse_args(vec!["disktest", "-Sx", "--verify", "/dev/foobar"]).unwrap();
        assert_eq!(a.device, PathBuf::from("/dev/foobar"));
        assert!(!a.write);
        assert!(a.verify);
        let a = parse_args(vec!["disktest", "-Sx", "-v", "/dev/foobar"]).unwrap();
        assert_eq!(a.device, PathBuf::from("/dev/foobar"));
        assert!(!a.write);
        assert!(a.verify);

        let a = parse_args(vec!["disktest", "-w", "--seek", "123", "/dev/foobar"]).unwrap();
        assert_eq!(a.seek, 123);
        let a = parse_args(vec!["disktest", "-w", "-s", "123 MiB", "/dev/foobar"]).unwrap();
        assert_eq!(a.seek, 123 * 1024 * 1024);

        let a = parse_args(vec!["disktest", "-w", "--bytes", "456", "/dev/foobar"]).unwrap();
        assert_eq!(a.max_bytes, 456);
        let a = parse_args(vec!["disktest", "-w", "-b", "456 MiB", "/dev/foobar"]).unwrap();
        assert_eq!(a.max_bytes, 456 * 1024 * 1024);

        let a = parse_args(vec![
            "disktest",
            "-w",
            "--algorithm",
            "CHACHA8",
            "/dev/foobar",
        ])
        .unwrap();
        assert_eq!(a.algorithm, DtStreamType::ChaCha8);
        let a = parse_args(vec!["disktest", "-w", "-A", "chacha8", "/dev/foobar"]).unwrap();
        assert_eq!(a.algorithm, DtStreamType::ChaCha8);
        let a = parse_args(vec!["disktest", "-w", "-A", "chacha12", "/dev/foobar"]).unwrap();
        assert_eq!(a.algorithm, DtStreamType::ChaCha12);
        let a = parse_args(vec!["disktest", "-w", "-A", "crc", "/dev/foobar"]).unwrap();
        assert_eq!(a.algorithm, DtStreamType::Crc);
        assert!(parse_args(vec!["disktest", "-w", "-A", "invalid", "/dev/foobar"]).is_err());

        let a = parse_args(vec!["disktest", "-w", "--seed", "mysecret", "/dev/foobar"]).unwrap();
        assert_eq!(a.seed, "mysecret");
        assert!(a.user_seed);
        let a = parse_args(vec!["disktest", "-w", "-S", "mysecret", "/dev/foobar"]).unwrap();
        assert_eq!(a.seed, "mysecret");
        assert!(a.user_seed);

        let a = parse_args(vec!["disktest", "-w", "--threads", "24", "/dev/foobar"]).unwrap();
        assert_eq!(a.threads, 24);
        let a = parse_args(vec!["disktest", "-w", "-j24", "/dev/foobar"]).unwrap();
        assert_eq!(a.threads, 24);
        let a = parse_args(vec!["disktest", "-w", "-j0", "/dev/foobar"]).unwrap();
        assert_eq!(a.threads, 0);
        assert!(parse_args(vec!["disktest", "-w", "-j65537", "/dev/foobar"]).is_err());

        let a = parse_args(vec!["disktest", "-w", "--quiet", "2", "/dev/foobar"]).unwrap();
        assert_eq!(a.quiet, DisktestQuiet::NoInfo);
        let a = parse_args(vec!["disktest", "-w", "-q2", "/dev/foobar"]).unwrap();
        assert_eq!(a.quiet, DisktestQuiet::NoInfo);

        let a = parse_args(vec!["disktest", "-w", "--invert-pattern", "/dev/foobar"]).unwrap();
        assert!(a.invert_pattern);
        let a = parse_args(vec!["disktest", "-w", "-i", "/dev/foobar"]).unwrap();
        assert!(a.invert_pattern);
    }

    /// Verify-only mode needs a user seed, and the device is mandatory.
    #[test]
    fn test_parse_args_errors() {
        // Neither --write nor --verify selects verify-only mode.
        assert!(parse_args(vec!["disktest", "/dev/foobar"]).is_err());
        assert!(parse_args(vec!["disktest", "-v", "/dev/foobar"]).is_err());
        // Missing positional device.
        assert!(parse_args(vec!["disktest", "-w"]).is_err());
        // A user seed combined with write + verify is kept as-is.
        let a = parse_args(vec!["disktest", "-w", "-v", "-Sabc", "/dev/foobar"]).unwrap();
        assert_eq!(a.seed, "abc");
        assert!(a.user_seed);
    }

    /// `--rounds` / `--start-round` normalization, including u64 boundaries.
    #[test]
    fn test_parse_args_rounds() {
        let a = parse_args(vec!["disktest", "-w", "/dev/foobar"]).unwrap();
        assert_eq!(a.rounds, 1);
        assert_eq!(a.start_round, 0);

        let a = parse_args(vec!["disktest", "-w", "-R3", "/dev/foobar"]).unwrap();
        assert_eq!(a.rounds, 3);
        assert_eq!(a.start_round, 0);

        // 0 rounds means infinite.
        let a = parse_args(vec!["disktest", "-w", "--rounds", "0", "/dev/foobar"]).unwrap();
        assert_eq!(a.rounds, u64::MAX);

        // A start round below the round count leaves the count untouched.
        let a = parse_args(vec![
            "disktest",
            "-w",
            "-R3",
            "--start-round",
            "1",
            "/dev/foobar",
        ])
        .unwrap();
        assert_eq!(a.rounds, 3);
        assert_eq!(a.start_round, 1);

        // A start round at or beyond the count raises it to start_round + 1.
        let a = parse_args(vec![
            "disktest",
            "-w",
            "-R3",
            "--start-round",
            "5",
            "/dev/foobar",
        ])
        .unwrap();
        assert_eq!(a.rounds, 6);
        assert_eq!(a.start_round, 5);

        // Infinite rounds are never lowered by a start round.
        let a = parse_args(vec![
            "disktest",
            "-w",
            "-R0",
            "--start-round",
            "5",
            "/dev/foobar",
        ])
        .unwrap();
        assert_eq!(a.rounds, u64::MAX);
        assert_eq!(a.start_round, 5);

        // The largest accepted start round does not overflow.
        let a = parse_args(vec![
            "disktest",
            "-w",
            "--start-round",
            "18446744073709551614",
            "/dev/foobar",
        ])
        .unwrap();
        assert_eq!(a.start_round, u64::MAX - 1);
        assert_eq!(a.rounds, u64::MAX);

        // u64::MAX itself is rejected by the value parser.
        assert!(parse_args(vec![
            "disktest",
            "-w",
            "--start-round",
            "18446744073709551615",
            "/dev/foobar",
        ])
        .is_err());
    }

    /// Boundaries of `--threads`, `--quiet` and `--algorithm` letter case.
    #[test]
    fn test_parse_args_limits() {
        let a = parse_args(vec!["disktest", "-w", "-j65536", "/dev/foobar"]).unwrap();
        assert_eq!(a.threads, 65536);

        let a = parse_args(vec!["disktest", "-w", "-q0", "/dev/foobar"]).unwrap();
        assert_eq!(a.quiet, DisktestQuiet::Normal);
        let a = parse_args(vec!["disktest", "-w", "-q1", "/dev/foobar"]).unwrap();
        assert_eq!(a.quiet, DisktestQuiet::Reduced);
        let a = parse_args(vec!["disktest", "-w", "-q3", "/dev/foobar"]).unwrap();
        assert_eq!(a.quiet, DisktestQuiet::NoWarn);
        // Levels above the documented range saturate to the quietest level.
        let a = parse_args(vec!["disktest", "-w", "-q255", "/dev/foobar"]).unwrap();
        assert_eq!(a.quiet, DisktestQuiet::NoWarn);
        assert!(parse_args(vec!["disktest", "-w", "-q256", "/dev/foobar"]).is_err());

        let a = parse_args(vec!["disktest", "-w", "-A", "ChaCha20", "/dev/foobar"]).unwrap();
        assert_eq!(a.algorithm, DtStreamType::ChaCha20);
        let a = parse_args(vec!["disktest", "-w", "-A", "CRC", "/dev/foobar"]).unwrap();
        assert_eq!(a.algorithm, DtStreamType::Crc);
    }
}

// vim: ts=4 sw=4 expandtab
